import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize

# Print arrays with three decimals and without scientific notation for small values.
np.set_printoptions(precision=3, suppress=True)

def output(k, inputs, nu, rho, beta):
    """Return output from capital ``k`` and the two variable inputs.

    The three terms ``k``, ``inputs[0]`` and ``beta*inputs[1]`` are each
    raised to ``rho``, summed, and the sum is raised to ``nu/rho``.

    Parameters
    ----------
    k : capital.
    inputs : indexable of length >= 2; ``inputs[0]`` enters directly and
        ``inputs[1]`` is first multiplied by ``beta``.
    nu, rho : exponents of the aggregate (``rho`` must be non-zero).
    beta : multiplier applied to ``inputs[1]``.
    """
    first_input = inputs[0]
    augmented_second = beta*inputs[1]
    aggregate = k**rho + first_input**rho + augmented_second**rho
    return aggregate**(nu/rho)

def profit(k, inputs, nu, rho, beta, prices):
    """Return output minus the cost of capital and of the variable inputs.

    ``prices[0]`` is the price applied to ``k``; ``prices[1:]`` are the
    prices of ``inputs`` (combined with a dot product, so the two must
    have matching lengths).
    """
    produced = output(k, inputs, nu, rho, beta)
    capital_bill = k*prices[0]
    variable_bill = np.dot(inputs, prices[1:])
    return produced - capital_bill - variable_bill

def SR_solution(k, nu, rho, beta, prices):
    """Choose the two variable inputs optimally with capital held at ``k``.

    Capital is treated as sunk: the objective subtracts ``k*prices[0]``
    again after negating ``profit``, which cancels the capital cost that
    ``profit`` charges, so the reported profit is output minus variable
    input costs only.

    Returns
    -------
    numpy.ndarray of length 7:
        ``[beta, k, input_0, input_1, Y, input_1/Y, Pi]``.
    """
    def negated_objective(x):
        # capital is treated as sunk
        return -profit(k, x, nu, rho, beta, prices) - k*prices[0]

    non_negative = [(0, None)]*2
    res = minimize(negated_objective, np.ones(2), bounds=non_negative)

    chosen = res.x
    Y = output(k, chosen, nu, rho, beta)
    Pi = -res.fun  # the optimizer minimised the negated objective
    return np.concatenate([[beta, k], chosen, [Y, chosen[1]/Y, Pi]])

def MR_solution(k, nu, rho, beta, prices, c0, c2, SR_solution):
    """Choose capital and both variable inputs subject to adjustment costs.

    The objective is ``profit`` at the new capital level, minus the fixed
    adjustment cost ``c0`` and the quadratic cost ``c2*(K - k)**2``, plus a
    rebate ``k*prices[0]`` for the capital already in place.

    Parameters
    ----------
    k : existing capital stock.
    c0, c2 : fixed and quadratic adjustment-cost coefficients.
    SR_solution : sequence whose second-to-last entry is the profit to
        beat and whose entries up to (not including) the last are returned
        when adjusting is not strictly better.  Inside this function the
        name shadows the module-level ``SR_solution`` function.

    Returns
    -------
    ``[beta, K, input_0, input_1, Y, input_1/Y, Pi]`` when the adjusted
    profit strictly exceeds ``SR_solution[-2]``; otherwise
    ``SR_solution[:-1]``.
    """
    def negated_objective(x):
        operating = profit(x[0], x[1:], nu, rho, beta, prices)
        # Rebate existing capital
        return -(operating - c0 - c2*(x[0]-k)**2 + k*prices[0])

    non_negative = [(0, None)]*3
    res = minimize(negated_objective, np.ones(3), bounds=non_negative)

    chosen = res.x
    Y = output(chosen[0], chosen[1:], nu, rho, beta)
    Pi = -res.fun

    # Written as "not (Pi > ...)" so that a NaN profit falls back to the
    # supplied solution.
    if not (Pi > SR_solution[-2]):
        return SR_solution[:-1]
    return np.concatenate([[beta], chosen, [Y, chosen[2]/Y, Pi]])


def LR_solution(nu, rho, beta, prices, discount=None):
    """Choose capital and both variable inputs freely to maximise ``profit``.

    Parameters
    ----------
    nu, rho, beta : passed through to ``profit`` / ``output``.
    prices : ``prices[0]`` is the price of capital, ``prices[1:]`` the
        prices of the two variable inputs.
    discount : currently unused.  It defaults to ``None`` so that the
        four-argument calls in this file work.
        NOTE(review): the original objective expression was left
        unterminated, so a discount term may have been intended here --
        confirm with the author.

    Returns
    -------
    numpy.ndarray of length 7:
        ``[beta, K, input_0, input_1, Y, input_1/Y, Pi + K*prices[0]]``.
        The capital cost is added back to the optimised profit, so the
        last entry is output minus variable input costs.
    """
    def negated_profit(x):
        return -profit(x[0], x[1:], nu, rho, beta, prices)

    non_negative = [(0, None)]*3
    #res = minimize(negated_profit, np.ones(3), bounds=non_negative, method='Nelder-Mead')
    res = minimize(negated_profit, np.ones(3), bounds=non_negative)

    inputs = res.x
    Y = output(inputs[0], inputs[1:], nu, rho, beta)
    Pi = -res.fun  # the optimizer minimised the negated profit
    return np.concatenate([[beta], inputs, [Y, inputs[2]/Y, Pi + inputs[0]*prices[0]]])



# Column layout of each solution table built below: [beta, K, L, E, Y, E/Y, Pi, ENTER]
def run_decomposition(P1b, P2b, a, nu, rho, FC, SV, c0, c2, Betas):
    """Average one outcome over groups of firms before and after a price change.

    For every value in ``Betas`` the firm's problem is solved at the initial
    prices ``a*P1b`` with ``LR_solution`` and then at the new prices
    ``a*P2b`` four ways: ``SR_solution`` (capital fixed at its initial
    level), ``MR_solution`` with adjustment costs ``c0``/``c2``,
    ``MR_solution`` with zero adjustment costs, and ``LR_solution``.

    Each table row is ``[beta, K, L, E, Y, E/Y, Pi, flag]``; the flag is 1
    when the profit column exceeds ``FC`` (entrant tables) or ``SV``
    (incumbent tables).  The number of flagged rows in each group is
    printed to stdout.

    Returns
    -------
    numpy.ndarray of length 6 holding the group means of column ``v``
    (E/Y), in the order: original entrants, all initial entrants with fixed
    capital, surviving fixed-capital incumbents, surviving
    partially-adjusting incumbents, flexible-capital incumbents, new
    entrants.  A group with no members yields NaN.
    """
    price_before = a*P1b  # Multiply base price times a
    price_after = a*P2b   # Multiply base price times a

    n_types = len(Betas)
    entrants0 = np.empty((n_types, 7 + 1))
    for i, b in enumerate(Betas):
        entrants0[i, :-1] = LR_solution(nu, rho, b, price_before)
    entrants0[:, -1] = entrants0[:, -2] > FC

    incumbents1, capital_adj1, capital_flex, entrants1 = (
        np.empty_like(entrants0) for _ in range(4)
    )
    for i, b in enumerate(Betas):
        installed_k = entrants0[i, 1]
        # View into incumbents1; its last entry is still unset here, but
        # MR_solution only reads the entries before it.
        fixed_row = incumbents1[i]
        fixed_row[:-1] = SR_solution(installed_k, nu, rho, b, price_after)
        capital_adj1[i, :-1] = MR_solution(installed_k, nu, rho, b, price_after, c0, c2, fixed_row)
        capital_flex[i, :-1] = MR_solution(installed_k, nu, rho, b, price_after, 0, 0, fixed_row)
        entrants1[i, :-1] = LR_solution(nu, rho, b, price_after)

    #Check Profits
    thresholds = ((incumbents1, SV), (capital_adj1, SV), (capital_flex, SV), (entrants1, FC))
    for table, cutoff in thresholds:
        table[:, -1] = table[:, -2] > cutoff

    entered = entrants0[:, -1] == 1
    operating = [
        entered,
        entrants1[:, -1] == 1,
        entered & (incumbents1[:, -1] == 1),
        entered & (capital_adj1[:, -1] == 1),
        entered & (capital_flex[:, -1] == 1),
    ]
    print()
    for mask in operating:
        print(np.sum(mask))

    def group_mean(table, mask):
        # Column-wise mean over the selected rows, dropping the flag column.
        return np.mean(table[mask, :-1], axis=0)

    original      = group_mean(entrants0, operating[0])
    fixed_inc     = group_mean(incumbents1, operating[0])
    surviving_inc = group_mean(incumbents1, operating[2])
    part_inc      = group_mean(capital_adj1, operating[3])
    # NOTE(review): flex_inc is averaged over the fixed-capital survivors
    # (operating[2]), not the flexible-capital survivors (operating[4]),
    # unlike part_inc which uses its own mask.  Kept as in the original --
    # confirm this is intended.
    flex_inc      = group_mean(capital_flex, operating[2])
    new_entrants  = group_mean(entrants1, operating[1])

    # Disabled diagnostics (previously kept inside a bare string literal):
    # print("Energy Price: ", P1[-1], '->', P2[-1])
    # print('original:            ', original)
    # print('fixed incumbents:    ', fixed_inc)
    # print('surviving incumbents:', surviving_inc)
    # print('flex incumbents:     ', flex_inc)
    # print('new entrants:        ', new_entrants)
    # print()
    #
    # total = new_entrants[5]-original[5]
    # print("Change in E/Y:           ", total)
    # print("original -> fixed:       ", round(100*(fixed_inc[5]-original[5])/total,1))
    # print("fixed -> surviving:      ", round(100*(surviving_inc[5] - fixed_inc[5])/total,1))
    # print("surviving -> flexible:   ", round(100*(flex_inc[5] - surviving_inc[5])/total,1))
    # print("flexible -> new entrants:", round(100*(new_entrants[5] - flex_inc[5])/total,1))
    #print("Note: These changes should sum to 1, modulo rounding")
    #print()
    #print("Note: Profits are not directly comparable (capital costs are counted differently")

    #Set "v" here to determine what (mean) variable gets outputted for each model type.
    #v = 0 -> energy productivity
    #v = 1,2,3 -> capital stock, labor inputs, electricity input
    #v = 4 -> output
    #v = 5 -> electricity/output
    #v = 6 -> profit
    v = 5 #Energy intensity
    groups = (original, fixed_inc, surviving_inc, part_inc, flex_inc, new_entrants)
    return np.array([group[v] for group in groups])

# Driver: sweep the initial energy price, run the decomposition for each
# value, and plot incumbents' energy intensity relative to new entrants.

FC = 50  # Entry cost: an entrant is flagged as operating only if its profit exceeds this
SV = 45  # Scrap value: an incumbent is flagged as operating only if its profit exceeds this

c0  = 1     # fixed cost of investment
c2  = 1e-8  # Quadratic cost of investment

nu  = 0.7   # Returns to scale
rho = -3.0  # Substitution exponent of the production function (negative -> complements)
a   = 0.025 # This scales prices to keep things balanced.

T = 500  # Number of sample points for distribution omega_e
Betas  = np.linspace(0.1, 2, T)  # energy productivity omega_e, evenly spaced between 0.1 and 2

fig, ax = plt.subplots()
#ax.set_title(r"Complements ($\rho=-0.5$)")

#price [of capital, of labor, of energy]
P2 = np.array([1,1,1])
prices = [np.array([1,1,p]) for p in np.linspace(0.2,1,25)]

# Column 0: initial energy price.  Columns 1-6: the six group means returned
# by run_decomposition (original, fixed, surviving, partial, flexible,
# new entrants).
result = np.empty((len(prices), 1+6))
for i, P1 in enumerate(prices):
    result[i,0]  = P1[2]
    result[i,1:] = run_decomposition(P1, P2, a, nu, rho, FC, SV, c0, c2, Betas)

# Express the incumbent columns as a ratio to the new-entrant column.  The
# y-axis label and the reference line at 1 below both assume this
# normalisation; it had been commented out, leaving a no-op self-assignment.
# NOTE(review): confirm the ratio (rather than raw levels) is what should be plotted.
# Alternative (relative deviation, pairs with a reference line at 0):
#result[:,2:-1] = (result[:,2:-1]-result[:,-1:])/result[:,-1:]
result[:,2:-1] = result[:,2:-1]/result[:,-1:]
ax.plot(result[:,0], result[:,3], ls='-.',  label='static with exit')
ax.plot(result[:,0], result[:,4], ls='--', label='partial capital and exit')
ax.plot(result[:,0], result[:,5], ls='-',  label='flexible capital and exit')
#ax.plot(result[:,0], np.zeros_like(result[:,5]), label='new entrants', color='k', ls='--')
ax.plot(result[:,0], np.ones_like(result[:,5]), label='new entrants', color='k', ls='--')
ax.set_xlabel("Initial Energy Price")
ax.set_ylabel("Current Energy Intensity \n Relative to Entrant")
ax.legend()
plt.show()
